package danran.dbapi.plugin.loader;

import danran.dbapi.plugin.log.Logger;

import java.io.*;
import java.util.*;
import java.util.jar.JarEntry;
import java.util.jar.JarFile;

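/**
 * @Classname PluginClassLoader
 * @Description Class loader for plugins: serves classes from loose {@code .class} files in a
 * {@code classes/} directory and from the jar files found in a {@code lib/} directory, both
 * located under a common base path.
 * @Date 2022/2/8 20:10
 * @Created by RanCoder
 */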
public class PluginClassLoader extends ClassLoader {
    private static final Logger log = Logger.getInstance(PluginClassLoader.class);

    /** Size of the scratch buffer used when draining a stream. */
    private static final int BUFFER_SIZE = 4096;

    /** File-name suffix that identifies a compiled class. */
    private static final String CLASS_SUFFIX = ".class";

    /**
     * Path of the directory that holds the jar files, or {@code null} when the loader was
     * created without a base path.
     */
    private final String lib;

    /**
     * Path of the directory that holds loose {@code .class} files, or {@code null} when the
     * loader was created without a base path.
     */
    private final String classes;

    /**
     * Class bytes read from jar files, keyed by binary class name. An entry is removed the
     * first time it is handed out for definition. This class adds no synchronization of its
     * own around this map.
     */
    private final Map<String, byte[]> map = new HashMap<>();

    /**
     * Creates a loader rooted at {@code basePath}. Jars are read eagerly from
     * {@code basePath/lib/}; loose class files are looked up on demand in
     * {@code basePath/classes/}.
     *
     * @param basePath directory containing the {@code lib} and {@code classes} sub-directories
     */
    public PluginClassLoader(String basePath) {
        lib = basePath + "/lib/";
        classes = basePath + "/classes/";
        readJarFile();
    }

    /**
     * Creates a loader with no base path. It only knows classes that are added later through
     * {@link #addJar(String)} or {@link #addJar(File)}.
     */
    public PluginClassLoader() {
        lib = null;
        classes = null;
    }

    /**
     * Defines the class {@code name} from the classes directory or from the bytes previously
     * read out of a jar. Following the parent's delegation model, this is only invoked for
     * classes the parent loader could not find, so system classes are never loaded here.
     *
     * @param name binary name of the class
     * @return the newly defined class, never {@code null}
     * @throws ClassNotFoundException if no bytes are available for {@code name}, or reading
     *                                the class file failed (the I/O error is the cause)
     */
    @Override
    protected Class<?> findClass(String name) throws ClassNotFoundException {
        if (name == null) {
            throw new ClassNotFoundException("null");
        }
        byte[] bytes;
        try {
            bytes = getClassFromFileOrMap(name);
        } catch (IOException e) {
            // Keep the I/O failure as the cause instead of hiding it behind a null return.
            throw new ClassNotFoundException(name, e);
        }
        if (bytes == null) {
            throw new ClassNotFoundException(name);
        }
        return defineClass(name, bytes, 0, bytes.length);
    }

    /**
     * Looks up the bytes of a class, preferring a class file in the classes directory and
     * falling back to the bytes read from jars.
     *
     * @param name binary name of the class
     * @return the class bytes, or {@code null} if the class is unknown to this loader
     * @throws IOException if the class file exists but cannot be read
     */
    private byte[] getClassFromFileOrMap(String name) throws IOException {
        // Without a classes directory there is nothing to probe on disk; concatenating a null
        // path would otherwise yield a bogus relative path starting with "null".
        if (classes != null) {
            File file = new File(classes + name.replace('.', File.separatorChar) + CLASS_SUFFIX);
            if (file.exists()) {
                try (InputStream input = new FileInputStream(file)) {
                    return readFully(input);
                }
            }
        }
        // Drop the reference so the raw bytes can be collected once the class is defined.
        return map.remove(name);
    }

    /**
     * Reads every jar in the lib directory. A jar that cannot be read is reported and skipped
     * so that one broken file does not prevent the remaining jars from being loaded.
     */
    private void readJarFile() {
        for (File jarFile : scanDir()) {
            try (JarFile jar = new JarFile(jarFile)) {
                readJar(jar);
            } catch (IOException e) {
                // Best effort: only Logger#info is visible here, so the stack trace is printed
                // to keep the failure details.
                e.printStackTrace();
            }
        }
    }

    /**
     * Reads all class entries of a jar and stores their bytes in this loader's map. Entries
     * whose class has already been loaded by this loader are skipped. The jar is not closed
     * by this method.
     *
     * @param jar the jar file to read
     * @throws IOException if an entry cannot be read
     */
    private void readJar(JarFile jar) throws IOException {
        Enumeration<JarEntry> entries = jar.entries();
        while (entries.hasMoreElements()) {
            JarEntry entry = entries.nextElement();
            String entryName = entry.getName();
            if (!entryName.endsWith(CLASS_SUFFIX)) {
                continue;
            }
            // Strip only the trailing suffix; a ".class" fragment elsewhere in the path must
            // stay part of the name.
            String clazz = entryName
                    .substring(0, entryName.length() - CLASS_SUFFIX.length())
                    .replace('/', '.');
            if (findLoadedClass(clazz) != null) {
                continue;
            }

            byte[] bytes;
            try (InputStream in = jar.getInputStream(entry)) {
                bytes = readFully(in);
            }
            map.put(clazz, bytes);
            log.info("装载类：" + clazz);
        }
    }

    /**
     * Scans the lib directory for jar files (non-recursive).
     *
     * @return the regular files ending in {@code .jar}; empty if the directory cannot be listed
     */
    private List<File> scanDir() {
        List<File> list = new ArrayList<>();
        File[] files = new File(lib).listFiles();
        if (files == null) {
            return list;
        }
        for (File f : files) {
            if (f.isFile() && f.getName().endsWith(".jar")) {
                list.add(f);
            }
        }
        return list;
    }

    /**
     * Adds the classes of a jar to this loader. A path that does not exist is ignored.
     *
     * @param jarPath path of the jar file
     * @throws IOException if the jar cannot be opened or read
     */
    public void addJar(String jarPath) throws IOException {
        addJar(new File(jarPath));
    }

    /**
     * Adds the classes of a jar to this loader. A file that does not exist is ignored.
     *
     * @param file the jar file
     * @throws IOException if the jar cannot be opened or read
     */
    public void addJar(File file) throws IOException {
        if (!file.exists()) {
            return;
        }
        // All class bytes are copied into memory, so the jar can be closed right away.
        try (JarFile jar = new JarFile(file)) {
            readJar(jar);
        }
    }

    /**
     * Drains a stream into a byte array. The stream is not closed by this method.
     *
     * @param in stream to read until end of input
     * @return all bytes read from the stream
     * @throws IOException if reading fails
     */
    private static byte[] readFully(InputStream in) throws IOException {
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        byte[] buffer = new byte[BUFFER_SIZE];
        int read;
        while ((read = in.read(buffer)) != -1) {
            out.write(buffer, 0, read);
        }
        return out.toByteArray();
    }

}
